import sys, os
sys.path.append(os.pardir)
import numpy as np
from common.util import im2col

if __name__ == "__main__":
    x1 = np.random.rand(1, 3, 7, 7)
    col1 = im2col(x1, 5, 5, stride=1, pad=0)
    print(col1.shape) #(9,75)

    x1_1 = np.random.rand(1, 3, 8, 8)
    col1_1 = im2col(x1_1, 5, 5, stride=1, pad=0)
    print(col1_1.shape)  # (16,75)

    '''
        比如说（3，7，7）3通道，高度7，长度7
        对于这个函数iml2col（长度5，高度5，幅度为1，不填充）
        对于一个通道7x7，利用滤波器（5，5），会有9个应用域。
        每个有25个数据，3个就是75个数据，展开成一列，一共是9列
        所以shape是(9,75)
        10个数据就是（90，75）
    '''

    x2 = np.random.rand(10, 3, 7, 7)
    col2 = im2col(x2, 5, 5, stride=1, pad=0)
    print(col2.shape) # (90,75)
